{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Example for using light module with pytorch\n",
    "\n",
    "In this notebook, we show how to train and evaluate a network with PyTorch in 5 minutes, with friendly log information written to both the screen and a local file.\n",
    "\n",
    "This example is organized in the following five parts:\n",
    "\n",
    "1. data preparation\n",
    "2. network definition\n",
    "3. configuration\n",
    "4. training\n",
    "5. additional\n",
    "\n",
    "## Problem Statement and Data Preparation\n",
    "\n",
    "We start from a classic learning problem, recognizing hand-written digits, and use scikit-learn to load the dataset.\n",
    "A description of the problem and the dataset can be found [here](https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 720x216 with 4 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAACXCAYAAAARS4GeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAALBUlEQVR4nO3dX2yd510H8O+vi8ooW2tnE0wU1sSdBAK0mqZTmZBQqjnSuJgcMRJNDDRXmhJxA5G4cG5gjsZQghByxYYWEGoZMFgjIJ2QCmq0uqMXgGLhTipsF2lamNikQp1uHfsjwcvFcUbUpmnzvufkxE8+HymSz+n5vs9j95dzvnlfH7u6rgsAQMtumvYGAAAmTeEBAJqn8AAAzVN4AIDmKTwAQPMUHgCgeU0Xnqp6tKo+NO7HcmMxRwxlhhgHczRMXW8/h6eqXrrk5i1Jvp3kf7ZuH+667s+u/a7Gq6rek+QTSd6e5B+TLHVd99x0d9WW1ueoqm5O8ukk9yS5I8l9XdetTXVTjbkBZuinknw0yZ6MPq+1JL/Sdd1Xprmv1twAc/RjST6V5M6tu9YzmqN/md6uLu+6O8PTdd2bLv5J8m9J3nfJfd8djKraMb1d9ldVb03yV0l+PcnOJGeTfGaqm2pQ63O05ckkv5jkq9PeSItugBmaTfIHSXZlVJq/nuTBaW6oRTfAHP1Hkp/P6PXsrUk+m+QvprqjV3HdFZ5XU1V7q+rLVbVcVV9N8mBVzVbV31TV81W1ufXxD12SWauqD299vFRVT1bV72w99nxV/WzPx+6uqs9X1der6kxVfaKq/vR1fio/l+TprutOdV33rSQrSe6qqh8d/lXitbQyR13XfafrutWu657M//9rkWugoRl6dOt56Gtd1/13ko8n+ekxfZl4DQ3N0YWu657tRpeLKqPno3eM56s0Xtum8Gx5W0Yt8o4khzLa/4Nbt9+e5JsZ/aV9Nfcm+VJGLfS3k/xRVVWPx346yT8leUtGheWXLg1W1Req6hde5bg/nuSpize6rvtGknNb93NttDBHTFeLM/QzSZ5+nY9lPJqZo6q6kORbSX4vyW9d6bHTst1Oof1vko90XfftrdvfTPKXF/9jVX0syeNXyD/Xdd0fbj32j5P8fpIfyOUvCVz2sTX63ol3JXlP13XfSfJkVX320mDXde+8wh7elOT5l933YpI3XyHDeLUwR0xXUzNUVe9M8htJFl/P4xmbZuao67qZqvq+JB9Kcl1+T+p2O8Pz/NZloCRJVd1SVSer6rmq+lqSzyeZqao3vEr+u0OwdQo3GRWQq3nsDyZ54ZL7kuTfr+JzeCnJrS+779aMrp9zbbQwR0xXMzNUVe9I8miSX+267u+vNs8gzczR1nG/keSTST5VVd/f5xiTtN0Kz8vfUvZrSX4kyb1d192a0SnZZHQdcVK+kmRnVd1yyX0/fBX5p5PcdfHGViO+M04lX0stzBHT1cQMVdUdSc4k+WjXdX8yzs3xujQxRy9zU0bvRrt90K4mYLsVnpd7c0anAC9U1c4kH5n0gltvHz+bZKWqbq6qdyd531Uc4q+T/ERVvb+q3pjRaeQvdF33xQlsl9dnO85Rqup7tmYoSW6uqjde4fo9k7XtZqiqbk/yuSQf77rukxPaJldnO87Rvqr6yap6Q1XdmuR3k2wm+dfJ7Li/7V54VpN8b5L/TPIPSf72Gq37wSTvTvJfSX4zo7eVX7wGm6p6uqo+eLlg13XPJ3l/ko9lNBT3JvnApDfMFa1mm83Rli9l9OR4e5K/2/r4jontlitZzfaboQ8nmcvohe6li38mvWGuaDXbb45mkvx5Rt+Lei6jKxbvvfRS3fXiuvvBg9tRVX0myRe7rpt4G6dd5oihzBDj0OocbfczPFNRVe+qqjur6qaqem9G72w4PeVtsc2YI4YyQ4zDjTJH2+1t6deLt2X005LfkuTLSX6567p/nu6W2IbMEUOZIcbhhpgj
l7QAgOa5pAUANO+1LmlN5fTPqVOnBuWXl5d7Z/ft29c7e/z48d7Z2dnZ3tkxmPRbmbflacS9e/f2zl64cKF39tixY72zi4tT/UG5k5yjbTlDa2trvbP79+/vnZ2fn++dHbLnMWjyuejEiROD8kePHu2d3b17d+/s+vp67+z1+JrmDA8A0DyFBwBonsIDADRP4QEAmqfwAADNU3gAgOYpPABA8xQeAKB5Cg8A0DyFBwBonsIDADRP4QEAmqfwAADNU3gAgObtmPYGLmd5eXlQ/vz5872zm5ubvbM7d+7snX344Yd7Z5PkwIEDg/K80szMTO/sE0880Tv7+OOP984uLi72zvJKGxsbg/L33Xdf7+xtt93WO/vss8/2znJ5R48e7Z0d+vx+8uTJ3tnDhw/3zq6vr/fOLiws9M5OijM8AEDzFB4AoHkKDwDQPIUHAGiewgMANE/hAQCap/AAAM1TeACA5ik8AEDzFB4AoHkKDwDQPIUHAGiewgMANE/hAQCat2NSBx7ya+XPnz8/aO1z5871zs7NzfXO7tu3r3d2yNcrSQ4cODAo36KNjY1B+bW1tbHs42rNz89PZV1e6fTp04Pyd911V+/s/v37e2ePHTvWO8vlHTp0qHd2eXl50Np79uzpnd29e3fv7MLCQu/s9cgZHgCgeQoPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPMUHgCgeQoPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDm7ZjUgTc3N3tn77777kFrz83NDcr3tWfPnqms27LV1dXe2ZWVlUFrv/jii4Pyfe3du3cq6/JKR44cGZTftWvXVNZeXFzsneXyhryuPPPMM4PWPn/+fO/swsJC7+yQ1/HZ2dne2UlxhgcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPMUHgCgeQoPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPN2TOrAQ36t/L59+8a4k2tnyOc8Ozs7xp2048iRI72zS0tLg9ae1v+TCxcuTGXdVg35eq6urg5a+/Tp04PyfT300ENTWZfLm5ubG5R/4YUXemcXFhamkj1z5kzvbDKZ519neACA5ik8AEDzFB4AoHkKDwDQPIUHAGiewgMANE/hAQCap/AAAM1TeACA5ik8AEDzFB4AoHkKDwDQPIUHAGiewgMANG/HpA485Fe7r6+vj3EnV2dzc7N39uzZs72zBw8e7J2lLRsbG72z8/PzY9tHK1ZWVnpnH3jggfFt5CqdPn26d3ZmZmZs+2D6hryenjlzpnf28OHDvbMnTpzonU2S48ePD8pfjjM8AEDzFB4AoHkKDwDQPIUHAGiewgMANE/hAQCap/AAAM1TeACA5ik8AEDzFB4AoHkKDwDQPIUHAGiewgMANE/hAQCat2NSB56bm+udPXv27KC1T506NZXsEMvLy1NZF1q3tLTUO7u2tjZo7aeeeqp3dv/+/b2zi4uLvbP3339/7+zQtVt19OjRQfmFhYXe2c3Nzd7Zxx57rHf24MGDvbOT4gwPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPMUHgCgeQoPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPN2TOrAc3NzvbMnTpwYtPby8nLv7D333NM7u76+3jvL+M3MzAzKLy4u9s4+8sgjvbNra2u9s0tLS72zrZqfn++d3djYGLT2kPzKykrv7JD527VrV+9sMuzvTatmZ2cH5Q8dOjSmnVydgwcP9s6ePHlyjDsZD2d4AIDmKTwAQPMUHgCgeQoPANA8hQcAaJ7CAwA0T+EBAJqn8AAAzVN4AIDmKTwAQPMUHgCgeQoPANA8hQcAaJ7CAwA0r7qum/YeAAAmyhkeAKB5Cg8A0DyFBwBonsIDADRP4QEAmqfwAADN+z+hHt0iyNm/ygAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "from sklearn import datasets\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "digits = datasets.load_digits()\n",
    "\n",
    "_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))\n",
    "for ax, image, label in zip(axes, digits.images, digits.target):\n",
    "    ax.set_axis_off()\n",
    "    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')\n",
    "    ax.set_title('Training: %i' % label)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "(array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],\n        [ 0.,  0.,  0., ..., 10.,  0.,  0.],\n        [ 0.,  0.,  0., ..., 16.,  9.,  0.],\n        ...,\n        [ 0.,  4., 14., ..., 11.,  0.,  0.],\n        [ 0.,  0.,  0., ...,  5.,  0.,  0.],\n        [ 0.,  0., 12., ...,  8.,  0.,  0.]]),\n array([0, 1, 2, ..., 3, 4, 5]))"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# flatten the images\n",
    "n_samples = len(digits.images)\n",
    "data = digits.images.reshape((n_samples, -1))\n",
    "\n",
    "# split dataset into train, validation and test\n",
    "# we set (train, validation): test = 8 : 2 and train : validation = 9 : 1\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    data, digits.target, test_size=0.2, shuffle=False)\n",
    "\n",
    "X_train, X_valid, y_train, y_valid = train_test_split(\n",
    "    X_train, y_train, test_size=0.1, shuffle=False)\n",
    "\n",
    "X_train, y_train"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# transform data to torch compatible format\n",
    "# `torch` is imported here so that this cell does not rely on a later cell\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "def transform(x, y, batch_size, **params):\n",
    "    \"\"\"Wrap features and labels in a batched ``DataLoader``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like\n",
    "        Features; converted to a ``torch.float32`` tensor.\n",
    "    y : array-like\n",
    "        Labels; converted to a ``torch.int64`` tensor.\n",
    "    batch_size : int\n",
    "        Number of samples per batch.\n",
    "    **params\n",
    "        Extra keyword arguments forwarded unchanged to ``DataLoader``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    DataLoader\n",
    "        Iterates over ``(features, labels)`` batches.\n",
    "    \"\"\"\n",
    "    dataset = TensorDataset(\n",
    "        torch.tensor(x, dtype=torch.float32),\n",
    "        torch.tensor(y, dtype=torch.int64)\n",
    "    )\n",
    "    return DataLoader(dataset, batch_size=batch_size, **params)\n",
    "\n",
    "# because the batch_size and some other batch-relevant parameters are\n",
    "# supposed to be determined in configuration,\n",
    "# we will perform the real data transformation after we set the configuration."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Network Definition"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Stack of ``Linear`` + ``Dropout(0.5)`` layers followed by an output ``Linear``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    in_feats : int\n",
    "        Size of each input sample.\n",
    "    hidden_layers : list of int, optional\n",
    "        Width of every hidden layer. ``None`` or an empty list builds a\n",
    "        single ``Linear(in_feats, out)`` layer.\n",
    "    out : int\n",
    "        Number of output classes.\n",
    "    **kwargs\n",
    "        Accepted and ignored.\n",
    "\n",
    "    NOTE(review): no non-linearity is inserted between the linear layers.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_feats, hidden_layers=None, out=10, **kwargs):\n",
    "        super(MLP, self).__init__()\n",
    "        # widths[k] -> widths[k + 1] describes the k-th hidden Linear layer;\n",
    "        # with no hidden layers the list only holds ``in_feats``\n",
    "        widths = [in_feats] + list(hidden_layers or [])\n",
    "        layers = []\n",
    "        for fan_in, fan_out in zip(widths[:-1], widths[1:]):\n",
    "            layers.append(nn.Linear(fan_in, fan_out))\n",
    "            layers.append(nn.Dropout(0.5))\n",
    "        # widths[-1] is ``in_feats`` when there are no hidden layers, so the\n",
    "        # default ``hidden_layers=None`` no longer raises\n",
    "        layers.append(nn.Linear(widths[-1], out))\n",
    "        self.nn = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return log-probabilities over the last dimension.\"\"\"\n",
    "        # NOTE(review): this notebook trains with CrossEntropyLoss, which\n",
    "        # applies log-softmax again; log-softmax is idempotent, so the loss\n",
    "        # value is unaffected, but one of the two is redundant.\n",
    "        return F.log_softmax(self.nn(x), dim=-1)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Configuration"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "data": {
      "text/plain": "logger: <Logger mlp (INFO)>\nmodel_name: mlp\nmodel_dir: mlp\nbegin_epoch: 0\nend_epoch: 2\nbatch_size: 32\nsave_epoch: 1\noptimizer: Adam\noptimizer_params: {'lr': 0.001, 'weight_decay': 0.0001}\nlr_params: {}\ntrain_select: None\nsave_select: None\nctx: cpu\ntoolbox_params: {}\nhyper_params: {'hidden_layers': [512, 128]}\ninit_params: {}\nloss_params: {}\ncaption: \nvalidation_result_file: mlp\\result.json\ncfg_path: mlp\\configuration.json"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from baize.torch import Configuration\n",
    "\n",
    "configuration = Configuration(model_name=\"mlp\", model_dir=\"mlp\")\n",
    "configuration.end_epoch = 2\n",
    "configuration.batch_size = 32\n",
    "configuration.hyper_params = {\"hidden_layers\": [512, 128]}\n",
    "\n",
    "configuration"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Loss function"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "data": {
      "text/plain": "CrossEntropyLoss()"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_f = torch.nn.CrossEntropyLoss()\n",
    "loss_f"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Training Procedure"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "def fit_f(_net, batch_data, trainer, loss_function, ctx=\"cpu\", *args, **kwargs):\n",
    "    \"\"\"Run one optimization step on a single batch and return its loss.\n",
    "\n",
    "    ``batch_data`` is a ``(features, labels)`` pair; both are moved to ``ctx``\n",
    "    before the forward pass. ``trainer`` is the optimizer being stepped.\n",
    "    \"\"\"\n",
    "    features, labels = (tensor.to(ctx) for tensor in batch_data)\n",
    "    # clear gradients left over from the previous step before accumulating new ones\n",
    "    trainer.zero_grad()\n",
    "    batch_loss = loss_function(_net(features), labels)\n",
    "    batch_loss.backward()\n",
    "    trainer.step()\n",
    "    return batch_loss"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Light Module"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "data": {
      "text/plain": "<torch.utils.data.dataloader.DataLoader at 0x2397a427b50>"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform data to torch compatible format\n",
    "train_data = transform(X_train, y_train, configuration.batch_size)\n",
    "train_data"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "data": {
      "text/plain": "MLP(\n  (nn): Sequential(\n    (0): Linear(in_features=64, out_features=512, bias=True)\n    (1): Dropout(p=0.5, inplace=False)\n    (2): Linear(in_features=512, out_features=128, bias=True)\n    (3): Dropout(p=0.5, inplace=False)\n    (4): Linear(in_features=128, out_features=10, bias=True)\n  )\n)"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Then set the network and initialize it\n",
    "net = MLP(in_feats=64, **configuration.hyper_params)\n",
    "net"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "outputs": [
    {
     "data": {
      "text/plain": "Adam (\nParameter Group 0\n    amsgrad: False\n    betas: (0.9, 0.999)\n    eps: 1e-08\n    lr: 0.001\n    weight_decay: 0\n)"
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define the trainer (also known as optimizer)\n",
    "from torch.optim import Adam\n",
    "trainer = Adam(net.parameters(), lr=0.001)\n",
    "trainer"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "After defining the necessary components needed in the training stage,\n",
    "we spend a little time reviewing the classical training procedure in PyTorch:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "41it [00:00, 336.65it/s]\n",
      "41it [00:00, 526.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0]SoftmaxCELoss: 1.4920656673791932\n",
      "[Epoch 1]SoftmaxCELoss: 0.46701456333805874\n"
     ]
    }
   ],
   "source": [
    "# An example in classic torch training style\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "for i in range(configuration.begin_epoch, configuration.end_epoch):\n",
    "    losses = []\n",
    "    # wrap the DataLoader itself rather than enumerate(...): tqdm can then read\n",
    "    # its length and show the total number of batches and a progress bar\n",
    "    for batch in tqdm(train_data):\n",
    "        loss_v = fit_f(\n",
    "            _net=net,\n",
    "            batch_data=batch,\n",
    "            trainer=trainer,\n",
    "            loss_function=loss_f,\n",
    "            ctx=configuration.ctx,\n",
    "        )\n",
    "        losses.append(loss_v.mean().item())\n",
    "    # report the mean batch loss of the epoch\n",
    "    print(\"[Epoch %d]SoftmaxCELoss: %s\" % (i, np.mean(losses)))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch| Total-E          Batch     Total-B       Loss-CrossEntropyLoss             Progress           \n",
      "    0|       1             41          41                     0.27995    [00:00<00:00, 252.21it/s]   \n",
      "Epoch| Total-E          Batch     Total-B       Loss-CrossEntropyLoss             Progress           \n",
      "    1|       1             41          41                    0.247723    [00:00<00:00, 176.44it/s]   \n"
     ]
    }
   ],
   "source": [
    "from baize.torch import light_module as lm\n",
    "lm.train(\n",
    "    net=net,\n",
    "    loss_function=loss_f,\n",
    "    cfg=configuration,\n",
    "    trainer=trainer,\n",
    "    train_data=train_data,\n",
    "    fit_f=fit_f,\n",
    "    initial_net=False\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Evaluation\n",
    "Then let us try to attach some evaluation procedure during training on the validation dataset:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "import torch\n",
    "from baize.metrics import classification_report\n",
    "\n",
    "def eval_f(_net, test_data, ctx=\"cpu\"):\n",
    "    \"\"\"Predict every batch of ``test_data`` and return a classification report.\n",
    "\n",
    "    The network is switched to evaluation mode while predicting and is put\n",
    "    back into the mode it was in before the call, even if prediction fails.\n",
    "    ``test_data`` yields ``(features, labels)`` batches; features are moved to\n",
    "    ``ctx`` and the predicted class is the arg-max over the last dimension.\n",
    "    \"\"\"\n",
    "    was_training = _net.training\n",
    "    _net.eval()\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    try:\n",
    "        # no gradients are needed for evaluation; skip building the autograd graph\n",
    "        with torch.no_grad():\n",
    "            for x, y in test_data:\n",
    "                x = x.to(ctx)\n",
    "                pred = _net(x).argmax(-1).tolist()\n",
    "                y_pred.extend(pred)\n",
    "                y_true.extend(y.tolist())\n",
    "    finally:\n",
    "        # restore the previous mode instead of forcing train mode\n",
    "        _net.train(was_training)\n",
    "    return classification_report(y_true, y_pred)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [
    {
     "data": {
      "text/plain": "<torch.utils.data.dataloader.DataLoader at 0x2397c0e29d0>"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_data = transform(X_valid, y_valid, configuration.batch_size)\n",
    "valid_data"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch| Total-E          Batch     Total-B       Loss-CrossEntropyLoss             Progress           \n",
      "    0|       1             41          41                    1.015863    [00:00<00:00, 208.68it/s]   \n",
      "Epoch [0]\tLoss - CrossEntropyLoss: 1.015863\n",
      "           precision    recall        f1  support\n",
      "0           0.937500  1.000000  0.967742       15\n",
      "1           0.588235  0.666667  0.625000       15\n",
      "2           1.000000  0.500000  0.666667       14\n",
      "3           0.764706  0.928571  0.838710       14\n",
      "4           0.928571  0.928571  0.928571       14\n",
      "5           1.000000  0.857143  0.923077       14\n",
      "6           0.866667  0.866667  0.866667       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           1.000000  0.571429  0.727273       14\n",
      "9           0.608696  1.000000  0.756757       14\n",
      "macro_avg   0.869437  0.831905  0.830046      144\n",
      "accuracy: 0.833333\n",
      "Epoch| Total-E          Batch     Total-B       Loss-CrossEntropyLoss             Progress           \n",
      "    1|       1             41          41                    0.361478    [00:00<00:00, 229.83it/s]   \n",
      "Epoch [1]\tLoss - CrossEntropyLoss: 0.361478\n",
      "           precision    recall        f1  support\n",
      "0           1.000000  1.000000  1.000000       15\n",
      "1           0.750000  1.000000  0.857143       15\n",
      "2           1.000000  0.785714  0.880000       14\n",
      "3           0.823529  1.000000  0.903226       14\n",
      "4           0.933333  1.000000  0.965517       14\n",
      "5           1.000000  0.857143  0.923077       14\n",
      "6           0.928571  0.866667  0.896552       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           0.916667  0.785714  0.846154       14\n",
      "9           1.000000  0.928571  0.962963       14\n",
      "macro_avg   0.935210  0.922381  0.923463      144\n",
      "accuracy: 0.923611\n"
     ]
    }
   ],
   "source": [
    "net = MLP(in_feats=64, **configuration.hyper_params)\n",
    "\n",
    "trainer = Adam(net.parameters(), lr=0.001)\n",
    "\n",
    "lm.train(\n",
    "    net=net,\n",
    "    loss_function=loss_f,\n",
    "    cfg=configuration,\n",
    "    trainer=trainer,\n",
    "    train_data=train_data,\n",
    "    test_data=valid_data,\n",
    "    fit_f=fit_f,\n",
    "    eval_f=eval_f,\n",
    "    initial_net=False\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "You may want to use tqdm to show the progress:"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 0: 100%|██████████| 41/41 [00:00<00:00, 442.07it/s]\n",
      "Epoch: 1: 100%|██████████| 41/41 [00:00<00:00, 456.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [0]\tLoss - CrossEntropyLoss: 0.324993\n",
      "           precision    recall        f1  support\n",
      "0           1.000000  1.000000  1.000000       15\n",
      "1           0.833333  1.000000  0.909091       15\n",
      "2           1.000000  1.000000  1.000000       14\n",
      "3           1.000000  1.000000  1.000000       14\n",
      "4           1.000000  0.714286  0.833333       14\n",
      "5           1.000000  0.928571  0.962963       14\n",
      "6           0.705882  0.800000  0.750000       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           1.000000  1.000000  1.000000       14\n",
      "9           1.000000  1.000000  1.000000       14\n",
      "macro_avg   0.953922  0.944286  0.945539      144\n",
      "accuracy: 0.944444\n",
      "Epoch [1]\tLoss - CrossEntropyLoss: 0.311763\n",
      "           precision    recall        f1  support\n",
      "0           0.937500  1.000000  0.967742       15\n",
      "1           0.937500  1.000000  0.967742       15\n",
      "2           1.000000  0.785714  0.880000       14\n",
      "3           0.933333  1.000000  0.965517       14\n",
      "4           1.000000  1.000000  1.000000       14\n",
      "5           1.000000  0.928571  0.962963       14\n",
      "6           0.937500  1.000000  0.967742       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           1.000000  1.000000  1.000000       14\n",
      "9           1.000000  1.000000  1.000000       14\n",
      "macro_avg   0.974583  0.971429  0.971171      144\n",
      "accuracy: 0.972222\n"
     ]
    }
   ],
   "source": [
    "lm.train(\n",
    "    net=net,\n",
    "    loss_function=loss_f,\n",
    "    cfg=configuration,\n",
    "    trainer=trainer,\n",
    "    train_data=train_data,\n",
    "    test_data=valid_data,\n",
    "    fit_f=fit_f,\n",
    "    eval_f=eval_f,\n",
    "    progress_monitor=\"tqdm\",\n",
    "    initial_net=False\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Simplest\n",
    "\n",
    "With some toolkit functions, we can furthermore make the code simpler."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch| Total-E          Batch     Total-B       Loss-cross entropy             Progress           \n",
      "    0|       1             41          41                 0.922297    [00:00<00:00, 250.67it/s]   \n",
      "Epoch [0]\tLoss - cross entropy: 0.922297\n",
      "           precision    recall        f1  support\n",
      "0           0.882353  1.000000  0.937500       15\n",
      "1           0.812500  0.866667  0.838710       15\n",
      "2           1.000000  0.571429  0.727273       14\n",
      "3           0.875000  1.000000  0.933333       14\n",
      "4           0.928571  0.928571  0.928571       14\n",
      "5           1.000000  0.857143  0.923077       14\n",
      "6           0.933333  0.933333  0.933333       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           1.000000  1.000000  1.000000       14\n",
      "9           0.764706  0.928571  0.838710       14\n",
      "macro_avg   0.919646  0.908571  0.906051      144\n",
      "accuracy: 0.909722\n",
      "Epoch| Total-E          Batch     Total-B       Loss-cross entropy             Progress           \n",
      "    1|       1             41          41                 0.326139    [00:00<00:00, 219.84it/s]   \n",
      "Epoch [1]\tLoss - cross entropy: 0.326139\n",
      "           precision    recall        f1  support\n",
      "0           0.937500  1.000000  0.967742       15\n",
      "1           0.882353  1.000000  0.937500       15\n",
      "2           1.000000  0.785714  0.880000       14\n",
      "3           0.875000  1.000000  0.933333       14\n",
      "4           0.933333  1.000000  0.965517       14\n",
      "5           1.000000  0.857143  0.923077       14\n",
      "6           0.933333  0.933333  0.933333       15\n",
      "7           1.000000  1.000000  1.000000       15\n",
      "8           1.000000  1.000000  1.000000       14\n",
      "9           1.000000  0.928571  0.962963       14\n",
      "macro_avg   0.956152  0.950476  0.950347      144\n",
      "accuracy: 0.951389\n"
     ]
    }
   ],
   "source": [
    "from baize.torch import fit_wrapper, loss_dict2tmt_torch_loss, eval_wrapper\n",
    "\n",
    "@fit_wrapper\n",
    "def fit_f(_net, batch_data, loss_function, *args, **kwargs):\n",
    "    \"\"\"Return the sum of all losses for one batch.\n",
    "\n",
    "    ``loss_function`` is read through ``.values()``, i.e. it maps a loss name\n",
    "    to a callable. No backward pass or optimizer step happens here;\n",
    "    presumably ``fit_wrapper`` takes care of them -- TODO confirm.\n",
    "    \"\"\"\n",
    "    x, y = batch_data\n",
    "    out = _net(x)\n",
    "    return sum(_f(out, y) for _f in loss_function.values())\n",
    "\n",
    "@eval_wrapper\n",
    "def eval_f(_net, test_data, *args, **kwargs):\n",
    "    \"\"\"Predict every batch of ``test_data`` and return a classification report.\n",
    "\n",
    "    No train/eval mode switch or device transfer happens here;\n",
    "    presumably ``eval_wrapper`` takes care of them -- TODO confirm.\n",
    "    \"\"\"\n",
    "    y_true = []\n",
    "    y_pred = []\n",
    "    for x, y in test_data:\n",
    "        pred = _net(x).argmax(-1).tolist()\n",
    "        y_pred.extend(pred)\n",
    "        y_true.extend(y.tolist())\n",
    "    return classification_report(y_true, y_pred)\n",
    "\n",
    "def get_net(*args, **kwargs):\n",
    "    \"\"\"Build an ``MLP`` for the 64-feature digit images.\"\"\"\n",
    "    # pass in_feats positionally: combining in_feats=64 with forwarded\n",
    "    # positional arguments raises \"got multiple values for argument\"\n",
    "    return MLP(64, *args, **kwargs)\n",
    "\n",
    "def get_loss(*args, **kwargs):\n",
    "    \"\"\"Build the loss from a one-entry dict holding a ``CrossEntropyLoss``.\"\"\"\n",
    "    return loss_dict2tmt_torch_loss(\n",
    "        {\"cross entropy\": torch.nn.CrossEntropyLoss(*args, **kwargs)}\n",
    "    )\n",
    "\n",
    "lm.train(\n",
    "    net=None,\n",
    "    cfg=configuration,\n",
    "    loss_function=None,\n",
    "    get_loss=get_loss,\n",
    "    trainer=None,\n",
    "    train_data=train_data,\n",
    "    test_data=valid_data,\n",
    "    fit_f=fit_f,\n",
    "    eval_f=eval_f,\n",
    "    initial_net=True,\n",
    "    get_net=get_net,\n",
    ")\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}